# -*- coding: utf-8 -*-

import sqlalchemy as sa
from sqlalchemy import event

from bopress import metas, options
from bopress.log import Logger
from bopress.model import Users, ObjectPermissions, UserObjectPermissions, UserRole
from bopress.orm import SessionFactory, get_ref_column
from bopress.settings import TABLE_NAME_PREFIX
from bopress.utils import Utils, DataResult

__author__ = 'yezang'


def login(user_key, user_pass):
    """
    Authenticate a user.

    The user may identify with any of login name, email or mobile phone; the
    password is compared against the stored ``Utils.md5`` digest and only
    users whose ``user_status`` is 1 can log in.

    :param user_key: value matched against user_login, user_email or user_mobile_phone
    :param user_pass: plain-text password
    :return: DataResult -- on success ``success`` is True and ``data`` is the
        bopress.model.Users instance; otherwise ``success`` is False and
        ``message`` is "401".
    """
    session = SessionFactory.session()
    identity = sa.or_(Users.user_login == user_key,
                      Users.user_email == user_key,
                      Users.user_mobile_phone == user_key)
    user = (session.query(Users)
            .filter(identity)
            .filter(Users.user_pass == Utils.md5(user_pass))
            .filter(Users.user_status == 1)
            .one_or_none())
    result = DataResult()
    if not user:
        result.success = False
        result.message = "401"
        return result
    result.success = True
    result.data = user
    return result


def get(user_id=None, user_login="", user_email="", user_mobile_phone=""):
    """
    Fetch a single user by whichever identifier is supplied.

    Identifiers are tried in this order and the first usable one decides the
    lookup: ``user_id`` (only when truthy and > 0), ``user_login``,
    ``user_email``, ``user_mobile_phone``.

    :param user_id: user id
    :param user_login: login name
    :param user_email: email address
    :param user_mobile_phone: mobile phone number
    :return: bopress.model.Users, or None when nothing matches / no identifier given
    """
    session = SessionFactory.session()
    criteria = (
        (bool(user_id and user_id > 0), Users.user_id, user_id),
        (bool(user_login), Users.user_login, user_login),
        (bool(user_email), Users.user_email, user_email),
        (bool(user_mobile_phone), Users.user_mobile_phone, user_mobile_phone),
    )
    for usable, column, value in criteria:
        if usable:
            return session.query(Users).filter(column == value).one_or_none()
    return None


def get_user_id(user_id=None, user_login="", user_email="", user_mobile_phone=""):
    """
    Resolve a user id from whichever identifier is supplied.

    ``user_login``, ``user_email`` and ``user_mobile_phone`` are tried in that
    order; the first non-empty one is looked up in the database. When none of
    them is given, ``user_id`` is returned unchanged (it is not validated).

    :param user_id: user id, returned as-is when no other identifier is given
    :param user_login: login name
    :param user_email: email address
    :param user_mobile_phone: mobile phone number
    :return: the user id, or None when the looked-up identifier matches no user
    """
    session = SessionFactory.session()
    lookups = ((Users.user_login, user_login),
               (Users.user_email, user_email),
               (Users.user_mobile_phone, user_mobile_phone))
    for column, value in lookups:
        if value:
            return session.query(Users.user_id).filter(column == value).scalar()
    return user_id


def save_user_roles(user_id, roles, create=False):
    """
    Persist a user's role names as UserRole rows and commit.

    :param user_id: id of the user the roles belong to
    :param roles: iterable of role names; nothing happens when it is empty/None
    :param create: True when the user is new, so there are no existing rows to
        remove; False replaces the user's current roles with *roles*
    :return: None
    """
    if not roles:
        return
    session = SessionFactory.session()
    if not create:
        # replace mode: drop the existing assignments first
        session.query(UserRole).filter(UserRole.user_id == user_id).delete()
    session.add_all([UserRole(user_id=user_id, role_name=role_name) for role_name in roles])
    session.commit()


def _normalize_roles(value):
    """
    Coerce the ``default_role`` site option into a set of role names.

    Missing/empty values (None, False, "", empty containers) give an empty set;
    a single string is one role name (``set("editor")`` would split it into
    characters); any other iterable is converted with ``set()``.
    """
    if not value:
        return set()
    if isinstance(value, str):
        return {value}
    return set(value)


def registry(user_login, user_pass, user_email="", user_mobile_phone="", user_nicename="", display_name="",
             user_status=0):
    """
    Register a new user.

    The registration is rejected when the login, the mobile phone (if given) or
    the email (if given) is already used by another user. Otherwise the user
    row is committed, the site's ``default_role`` roles are assigned and the
    initial ``bo_*`` user metas are written.

    NOTE: the user row is committed before roles and metas are saved, so a
    failure in those later steps leaves the user registered without them.

    :param user_login: login name
    :param user_pass: plain-text password
    :param user_email: email
    :param user_mobile_phone: mobile phone
    :param user_nicename: nice name
    :param display_name: display name
    :param user_status: status 0 or 1
    :return: DataResult -- ``message`` is "UserLoginExists",
        "UserMobilePhoneExists", "UserEmailExists", "500" on failure, or "" on
        success, in which case ``data`` is the new Users instance.
    """
    r = DataResult()
    s = SessionFactory.session()
    if s.query(Users).filter(Users.user_login == user_login).count() > 0:
        r.success = False
        r.message = "UserLoginExists"
        return r
    if user_mobile_phone and s.query(Users).filter(Users.user_mobile_phone == user_mobile_phone).count() > 0:
        r.success = False
        r.message = "UserMobilePhoneExists"
        return r
    if user_email and s.query(Users).filter(Users.user_email == user_email).count() > 0:
        r.success = False
        r.message = "UserEmailExists"
        return r
    u = Users()
    try:
        u.user_login = user_login
        # NOTE(review): MD5 is a weak password hash; it must stay in sync with
        # `login`, so switching algorithms needs a migration of stored hashes.
        u.user_pass = Utils.md5(user_pass)
        u.user_email = user_email
        u.user_mobile_phone = user_mobile_phone
        u.user_nicename = user_nicename
        u.display_name = display_name
        u.user_status = user_status
        u.user_activation_key = Utils.uniq_index()
        u.user_registered = Utils.current_datetime()
        s.add(u)
        s.commit()
        site_options = options.get_site_options() or {}
        # a missing option must not crash (set(False) raises TypeError) and a
        # single role name must not be split into characters
        default_roles = _normalize_roles(site_options.get("default_role"))
        save_user_roles(u.user_id, default_roles, True)
        metas.save_user_metas(u.user_id, "bo_super", False)
        metas.save_user_metas(u.user_id, "bo_roles", default_roles)
        metas.save_user_metas(u.user_id, "bo_description", "")
        metas.save_user_metas(u.user_id, "bo_gravatar", "")
        r.success = True
        r.message = ""
        r.data = u
        return r
    except Exception as e:
        Logger.exception(e)
        # a failed flush/commit leaves the session unusable until rolled back
        try:
            s.rollback()
        except Exception as rollback_error:
            Logger.exception(rollback_error)
        r.success = False
        r.message = "500"
        r.data = None
        return r


def add_capability(name, group="general", description=""):
    """
    Register (or overwrite) a capability in the ``bo_capabilities`` option.

    :param name: unique capability name
    :param group: group label, used to organize capabilities when assigning them
    :param description: what the capability is for
    :return: None. Nothing is saved when the ``bo_capabilities`` option is
        missing or empty.
    """
    # NOTE(review): an empty option also blocks the very first capability;
    # presumably the option is seeded elsewhere -- confirm.
    capabilities = options.get_options("bo_capabilities")
    if capabilities:
        capabilities[name] = (group, description)
        options.save_options("bo_capabilities", capabilities)


def add_role(name, display_name="", capabilities=None):
    """
    Register (or overwrite) a role in the ``bo_roles`` option.

    The role is stored as ``[display_name, capabilities]``.

    :param name: role name
    :param display_name: human-readable name; defaults to *name* when empty
    :param capabilities: capability names as a list, tuple or set; None (the
        default) means no capabilities
    :return: None. Nothing is saved when *capabilities* has an unsupported type
        or when the ``bo_roles`` option is missing or empty.
    """
    if capabilities is None:
        capabilities = []
    # reject strings and other iterables: a str would be stored as one
    # "capability list" and later split into characters by set()
    if not isinstance(capabilities, (list, tuple, set, frozenset)):
        Logger.error("capabilities require list type")
        return
    roles = options.get_options("bo_roles")
    if not roles:
        return
    # store a list copy so the caller's object is not shared with the option
    roles[name] = [display_name or name, list(capabilities)]
    options.save_options("bo_roles", roles)


class Auth(object):
    """
    Role-based capability checks for a single user.

    The user's role names come from the ``bo_roles`` user meta; the platform
    role table comes from the ``bo_roles`` option and maps each role name to
    ``[display_name, capabilities]``.
    """

    def __init__(self, user_id=None, user_login="", user_email="", user_mobile_phone=""):
        # resolve whichever identifier was given into a user id
        self._user_id = get_user_id(user_id, user_login, user_email, user_mobile_phone)

    def has_perms(self, caps=""):
        """
        Check whether the user holds every capability in *caps*.

        :param caps: a single capability name (str) or a list/set of names
        :return: bool -- always False when the user could not be resolved
        """
        if not self._user_id:
            return False
        # a bare string is ONE capability; set("edit") would split it into characters
        if isinstance(caps, str):
            caps = [caps]
        return set(caps).issubset(self.get_perms())

    def get_perms(self):
        """
        Collect the capabilities granted by all of the user's roles.

        :return: set of capability names; empty when the user is unknown, has
            no roles, or no platform roles are configured
        """
        perms = set()
        if not self._user_id:
            return perms
        roles = metas.get_user_metas(self._user_id, "bo_roles") or ()
        platform_roles = options.get_options("bo_roles") or {}
        for role, definition in platform_roles.items():
            if role in roles:
                perms |= set(definition[1])
        return perms

    def is_super(self):
        """
        Return the user's ``bo_super`` meta value (False when the user is unknown).
        """
        if not self._user_id:
            return False
        return metas.get_user_metas(self._user_id, "bo_super")


class ObjectPermissionChecker(object):
    """
    Object-level permission helper bound to a single user.

    Storage model used by the methods below:

    * ``UserObjectPermissions`` rows hold ``(user_id, cls_name, perm)``; a row
      whose ``user_id`` is NULL is treated as a role permission.
    * ``ObjectPermissions`` rows link one of those rows to a concrete object
      through ``object_id``, which is stored as ``str(pk)``.
    """

    def __init__(self, user_id=None, user_login="", user_email="", user_mobile_phone=""):
        self._user_id = get_user_id(user_id, user_login, user_email, user_mobile_phone)

    @staticmethod
    def klass_name(model_object):
        """Return ``"module.ClassName"`` for the class of *model_object*."""
        return "{0}.{1}".format(model_object.__class__.__module__, model_object.__class__.__name__)

    @staticmethod
    def _to_cap_set(caps):
        """Normalize *caps* (one capability name or an iterable of names) into a set."""
        if isinstance(caps, str):
            return {caps}
        return set(caps)

    @staticmethod
    def _pk_value(model_object):
        """
        Return the primary-key value of *model_object*, or None when unusable.

        A falsy *model_object* yields None silently; a falsy pk value is logged.
        """
        if not model_object:
            return None
        pk = getattr(model_object, get_ref_column(type(model_object)).key)
        if not pk:
            # TODO throw a custom exception
            Logger.error("Model instance pk value None!")
            return None
        return pk

    @staticmethod
    def _object_perm_query(session, klass_full_name, pk):
        """Query of the ``perm`` values attached to one object (class name + pk)."""
        return session.query(UserObjectPermissions.perm) \
            .select_from(ObjectPermissions).outerjoin(UserObjectPermissions) \
            .filter(UserObjectPermissions.cls_name == klass_full_name) \
            .filter(ObjectPermissions.object_id == str(pk))

    @staticmethod
    def _object_id_query(session, pk_column, klass_full_name):
        """Query of protected object ids of a class, cast to *pk_column*'s type and labelled ``object_id``."""
        return session.query(sa.cast(ObjectPermissions.object_id, pk_column.type).label("object_id")) \
            .outerjoin(UserObjectPermissions) \
            .filter(UserObjectPermissions.cls_name == klass_full_name)

    def _create_user_object_permission(self, m):
        """
        Return the stored UserObjectPermissions row matching *m*'s (cls_name, perm)
        for this checker's user, inserting and committing *m* when none exists.
        """
        s = SessionFactory.session()
        # first() instead of scalar(): tolerate accidental duplicate rows
        existing = s.query(UserObjectPermissions) \
            .filter(UserObjectPermissions.cls_name == m.cls_name) \
            .filter(UserObjectPermissions.perm == m.perm) \
            .filter(UserObjectPermissions.user_id == self._user_id).first()
        if existing:
            return existing
        s.add(m)
        s.commit()
        return m

    def assign(self, caps="", model_object=None):
        """
        Grant capabilities on *model_object* to this checker's user.

        :param caps: str or list of capability names
        :param model_object: model instance with a pk value
        :return: None
        """
        if not caps or not model_object:
            return
        # TODO check caps in platform caps?
        pk = self._pk_value(model_object)
        if pk is None:
            return
        klass_full_name = self.klass_name(model_object)
        s = SessionFactory.session()
        for cap in self._to_cap_set(caps):
            uop = UserObjectPermissions()
            uop.user_id = self._user_id
            uop.cls_name = klass_full_name
            uop.perm = cap
            m = self._create_user_object_permission(uop)
            op = ObjectPermissions()
            op.object_id = str(pk)
            op.user_object_permission_id = m.user_object_permission_id
            s.add(op)
        s.commit()

    def has_perms(self, caps="", model_object=None, use_role=False):
        """
        Check capabilities on *model_object* (see ``get_perms``).

        :param caps: str or list of capability names; all must be held
        :param model_object: model instance with a pk value
        :param use_role: check role permissions instead of the user's own
        :return: bool
        """
        perms = self.get_perms(model_object, use_role)
        if isinstance(caps, str):
            return caps in perms
        return perms.issuperset(set(caps))

    def has_role_perms(self, caps="", model_object=None):
        """Role-permission variant of ``has_perms``; returns bool."""
        return self.has_perms(caps, model_object, True)

    def _get_object_perms(self, model_object=None, use_role=False):
        """
        Return the permission names attached to *model_object*.

        :param model_object: model instance
        :param use_role: True returns role permissions (user_id NULL), False the
            permissions granted to this checker's user
        :return: set of permission names (empty when the object is unusable)
        """
        pk = self._pk_value(model_object)
        if pk is None:
            return set()
        s = SessionFactory.session()
        q = self._object_perm_query(s, self.klass_name(model_object), pk)
        if use_role:
            q = q.filter(UserObjectPermissions.user_id.is_(None))
        else:
            q = q.filter(UserObjectPermissions.user_id == self._user_id)
        return {row[0] for row in q.all()}

    @staticmethod
    def _get_object_all_perms(model_object=None):
        """
        Return every permission name attached to *model_object*, for any user or role.

        :param model_object: model instance
        :return: set of permission names (empty when the object is unusable)
        """
        pk = ObjectPermissionChecker._pk_value(model_object)
        if pk is None:
            return set()
        s = SessionFactory.session()
        q = ObjectPermissionChecker._object_perm_query(s, ObjectPermissionChecker.klass_name(model_object), pk)
        return {row[0] for row in q.all()}

    def get_perms(self, model_object=None, use_role=False):
        """
        Return the effective permissions on *model_object*.

        Super users get every permission attached to the object; other users
        get the object permissions intersected with their role capabilities.

        ..testnode:
            user_obj_perms = self.get_perms(Post(post_id=1))
        :param model_object: model instance
        :param use_role: ``True`` uses role object permissions, ``False`` the user's own
        :return: set of permission names
        """
        auth = Auth(self._user_id)
        if auth.is_super():
            # return all.
            return self._get_object_all_perms(model_object)
        object_perms = self._get_object_perms(model_object, use_role)
        user_perms = auth.get_perms()
        return object_perms & user_perms

    def remove_perms(self, caps="", model_object=None, use_role=False):
        """
        Delete permission links on *model_object*.

        :param caps: str or list of capability names to remove
        :param model_object: model instance with a pk value
        :param use_role: remove role permissions instead of the user's own
        :return: number of ObjectPermissions rows deleted, or None when the
            object is unusable
        """
        pk = self._pk_value(model_object)
        if pk is None:
            return None
        klass_full_name = self.klass_name(model_object)
        s = SessionFactory.session()
        q = s.query(ObjectPermissions).outerjoin(UserObjectPermissions) \
            .filter(UserObjectPermissions.cls_name == klass_full_name) \
            .filter(ObjectPermissions.object_id == str(pk))
        if isinstance(caps, str):
            q = q.filter(UserObjectPermissions.perm == caps)
        else:
            q = q.filter(UserObjectPermissions.perm.in_(caps))
        if use_role:
            q = q.filter(UserObjectPermissions.user_id.is_(None))
        else:
            q = q.filter(UserObjectPermissions.user_id == self._user_id)
        effect_num = 0
        for op in q.all():
            s.delete(op)
            effect_num += 1
        s.commit()
        return effect_num

    def get_roles(self, model_object=None, use_subquery=True):
        """
        Return the UserRole rows whose role shares at least one capability with
        the role permissions attached to *model_object*.

        :param model_object: model instance with a pk value
        :param use_subquery: True returns a subquery, False the list of rows
        :return: subquery, list of UserRole, or None when the object is unusable
        """
        pk = self._pk_value(model_object)
        if pk is None:
            return None
        s = SessionFactory.session()
        role_caps = self._get_object_perms(model_object, True)
        platform_roles = options.get_options("bo_roles") or {}
        roles = {name for name, definition in platform_roles.items() if set(definition[1]) & role_caps}
        q = s.query(UserRole).filter(UserRole.role_name.in_(roles))
        if use_subquery:
            return q.subquery()
        return q.all()

    def get_users(self, model_object=None, use_role=False, use_subquery=True):
        """
        Return the users that have access to *model_object*.

        :param model_object: model instance with a pk value
        :param use_role: False returns users with a direct object permission,
            True returns users holding a role found by ``get_roles``
        :param use_subquery: True returns a subquery, False the list of Users
        :return: subquery, list of Users, or None when the object is unusable
        """
        pk = self._pk_value(model_object)
        if pk is None:
            return None
        s = SessionFactory.session()
        if use_role:
            roles_subq = self.get_roles(model_object)
            q = s.query(Users).select_from(roles_subq).outerjoin(Users, roles_subq.c.user_id == Users.user_id)
        else:
            klass_full_name = self.klass_name(model_object)
            q = s.query(Users) \
                .select_from(ObjectPermissions).outerjoin(UserObjectPermissions) \
                .outerjoin(Users, Users.user_id == UserObjectPermissions.user_id) \
                .filter(UserObjectPermissions.cls_name == klass_full_name) \
                .filter(ObjectPermissions.object_id == str(pk)) \
                .filter(UserObjectPermissions.user_id.isnot(None)) \
                .group_by(UserObjectPermissions.user_id)
        if use_subquery:
            return q.subquery()
        return q.all()

    def get_objects(self, caps="", klass=None, union=True, use_role=False, any_perm=False):
        """
        Build a subquery of object ids protected by object permissions.

        The subquery has a single column, ``object_id``, cast to the pk type of
        *klass*, and lists each id at most once.

            example::
                p = ObjectPermissionChecker(user_id=2)
                subq = p.get_objects(["edit_post", "r_edit_post"], Posts)
                s = SessionFactory.session()
                q = s.query(Posts.title, Posts.content)\
                        .select_from(subq).outerjoin(Posts, Posts.post_id == subq.c.object_id)
                print(q.all())

        Super users get every protected object of *klass*. For other users only
        the capabilities they hold through their roles are considered.

        :param caps: `str` or `list`, capability name(s)
        :param class klass: model class
        :param bool union: ``True`` returns objects granted to the user plus
            objects granted to roles for ``r_``-prefixed caps; ``False`` defers
            to ``use_role``. NOTE: with ``any_perm`` also True no user or
            capability filter is applied at all.
        :param bool use_role: (when ``union`` is False) ``False`` returns objects
            granted to the user, ``True`` objects granted to roles (``r_`` caps only).
        :param bool any_perm: ``True`` skips the capability filter
        :return: ``subquery`` or ``None``
        """
        if not caps or not klass:
            return None
        if not self._user_id:
            return None
        pk_column = get_ref_column(klass)
        klass_full_name = "{0}.{1}".format(klass.__module__, klass.__name__)
        s = SessionFactory.session()
        q = self._object_id_query(s, pk_column, klass_full_name)
        auth = Auth(self._user_id)
        # DISTINCT: an object linked through several permissions would
        # otherwise appear once per permission and duplicate joined rows
        if auth.is_super():
            # return all.
            return q.distinct().subquery()
        valid_perms = auth.get_perms() & self._to_cap_set(caps)
        if not valid_perms:
            return None
        user_owned = UserObjectPermissions.user_id == self._user_id
        role_owned = UserObjectPermissions.user_id.is_(None)
        if union:
            if not any_perm:
                cond = sa.and_(user_owned, UserObjectPermissions.perm.in_(sorted(valid_perms)))
                role_valid_perms = self._valid_role_perms(valid_perms)
                if role_valid_perms:
                    cond = sa.or_(cond, sa.and_(role_owned,
                                                UserObjectPermissions.perm.in_(sorted(role_valid_perms))))
                q = q.filter(cond)
            return q.distinct().subquery()
        if not use_role:
            # user objects
            q = q.filter(user_owned)
            if not any_perm:
                q = q.filter(UserObjectPermissions.perm.in_(sorted(valid_perms)))
            return q.distinct().subquery()
        # role objects: accept caps starting with `r_` only
        role_valid_perms = self._valid_role_perms(valid_perms)
        if not role_valid_perms:
            return None
        q = q.filter(role_owned)
        if not any_perm:
            q = q.filter(UserObjectPermissions.perm.in_(sorted(role_valid_perms)))
        return q.distinct().subquery()

    @staticmethod
    def _valid_role_perms(perms):
        """Return the subset of *perms* that are role capabilities (prefixed ``r_``)."""
        return {p for p in perms if p.startswith("r_")}


def remove_perms_for_bulk_delete(session, entity_cls, object_id_set):
    """
    Delete the object-permission links of many objects of one class.

    Use this alongside bulk ``Query.delete()`` calls, which do not fire the
    per-object ``after_delete`` hook. The deletion is not committed here.

    :param session: `.orm.SessionFactory` session
    :param class entity_cls: model class of the deleted objects
    :param object_id_set: list or set of primary-key values
    :return: None
    """
    if not object_id_set:
        return
    # ObjectPermissionChecker.assign stores object_id as str(pk); compare
    # strings so strict databases do not reject a text = integer comparison
    object_ids = sorted({str(object_id) for object_id in object_id_set})
    klass_full_name = "{0}.{1}".format(entity_cls.__module__, entity_cls.__name__)
    rows = session.query(ObjectPermissions.object_permission_id) \
        .outerjoin(UserObjectPermissions) \
        .filter(UserObjectPermissions.cls_name == klass_full_name) \
        .filter(ObjectPermissions.object_id.in_(object_ids)).all()
    permission_ids = {row[0] for row in rows}
    if permission_ids:
        session.query(ObjectPermissions) \
            .filter(ObjectPermissions.object_permission_id.in_(permission_ids)) \
            .delete(synchronize_session=False)


def remove_perms_on_delete():
    """
    Class-decorator factory that deletes an object's permission links when
    the object itself is deleted.

    Usage: decorate a mapped model class with ``@remove_perms_on_delete()``.
    The hook fires for per-object deletes::

        s = SessionFactory.session()
        s.delete(someobject)
        s.commit()

    Bulk deletes such as::

        s.query(Posts).filter(Posts.post_id == 2).delete()

    do not fire it; call the helper ``remove_perms_for_bulk_delete`` instead::

        remove_perms_for_bulk_delete(s, Posts, [2])

    :return: the class decorator
    """

    def wrapp(cls):
        @event.listens_for(cls, 'after_delete')
        def on_delete(mapper, connection, target):
            entity_cls = type(target)
            klass_full_name = "{0}.{1}".format(entity_cls.__module__, entity_cls.__name__)
            pk_val = getattr(target, get_ref_column(entity_cls).key)
            t_object_permissions = "{0}objectpermissions".format(TABLE_NAME_PREFIX)
            t_user_object_permissions = "{0}userobjectpermissions".format(TABLE_NAME_PREFIX)
            # sa.text + named binds: SQLAlchemy renders the dialect's own
            # paramstyle (raw "?" placeholders only work with qmark DBAPIs
            # such as sqlite3, and raw string execution is gone in 2.0)
            select_sql = sa.text(
                "SELECT {op}.object_permission_id FROM {op} LEFT JOIN {uop}"
                " ON {uop}.user_object_permission_id={op}.user_object_permission_id"
                " WHERE {uop}.cls_name=:cls_name AND {op}.object_id=:object_id"
                .format(op=t_object_permissions, uop=t_user_object_permissions))
            # object_id is stored as str(pk) by ObjectPermissionChecker.assign
            result = connection.execute(select_sql, {"cls_name": klass_full_name, "object_id": str(pk_val)})
            permission_ids = [row[0] for row in result.fetchall()]
            if not permission_ids:
                return
            delete_sql = sa.text(
                "DELETE FROM {op} WHERE {op}.object_permission_id=:object_permission_id"
                .format(op=t_object_permissions))
            # one executemany round-trip instead of one statement per row
            connection.execute(delete_sql, [{"object_permission_id": pid} for pid in permission_ids])

        return cls

    return wrapp
